#!/usr/bin/python3

import argparse
import ipaddress
import json
import os
import shutil
import string
import subprocess
import sys
import time

# Base AS number; configure() passes BASE_ASN + rack number to bird.
BASE_ASN = 64600
# Prefix length of a node subnet; used to derive the ToR addresses and
# substituted as MASK into the bird configuration template.
NODE_MASK = 26

# Files written at the end of configure() to record the chosen settings.
RACK_INFORMATION = '/etc/neco/rack'
CLUSTER_INFORMATION = '/etc/neco/cluster'
SABAKAN_IPAM = '/etc/neco/sabakan_ipam.json'

# NOTE: this runs an external command at import time.  VM is True when
# `systemd-detect-virt -v` exits with status 0; per systemd-detect-virt(1)
# that means virtual machine (hardware) virtualization was detected.
VM = subprocess.run(['systemd-detect-virt', '-v'],
                    stdout=subprocess.DEVNULL).returncode == 0
# Names of the two network interfaces that receive link addresses; the
# names differ between virtual machines and other hosts.
ETH0 = 'ens3' if VM else 'eno1'
ETH1 = 'ens4' if VM else 'eno2'


class Cluster:
    """Read-only settings of a single cluster.

    The constructor parameters double as the keys expected in each JSON
    object of the cluster definition file, because instances are created
    with ``Cluster(**entry)``.
    """

    # pylint: disable=too-many-arguments
    def __init__(self, name: str, bastion_network: str, bmc_network: str,
                 ntp_servers: [str]):
        """Parse and validate the raw settings.

        Raises:
            ipaddress.AddressValueError: bastion_network or bmc_network is
                not a valid IPv4 address / network.
            ipaddress.NetmaskValueError: bmc_network has an invalid mask.
            ValueError: bmc_network has host bits set.
            TypeError: ntp_servers is not a list.
        """
        # Validation happens in this order: bastion address, NTP server
        # list, BMC network.  The first failing check decides the error.
        parsed_bastion = ipaddress.IPv4Address(bastion_network)
        if not isinstance(ntp_servers, list):
            raise TypeError('ntp_servers must be a list')
        parsed_bmc = ipaddress.IPv4Network(bmc_network)

        self._name = name
        self._bastion_network = parsed_bastion
        self._ntp_servers = ntp_servers
        self._bmc_network = parsed_bmc

    @property
    def name(self) -> str:
        """Cluster name."""
        return self._name

    @property
    def bastion_network(self) -> ipaddress.IPv4Address:
        """Bastion base address (a single IPv4Address, not a network)."""
        return self._bastion_network

    @property
    def bmc_network(self) -> ipaddress.IPv4Network:
        """IPv4 network of the BMCs."""
        return self._bmc_network

    @property
    def ntp_servers(self) -> [str]:
        """NTP servers, exactly the list object given to the constructor."""
        return self._ntp_servers


def load_clusters(filename: str) -> [Cluster]:
    """Load cluster definitions from a JSON file.

    The file must contain a JSON array; every element is expanded as
    keyword arguments to the Cluster constructor.

    Raises:
        TypeError: the top-level JSON value is not an array.
    """
    with open(filename) as source:
        entries = json.load(source)

    if isinstance(entries, list):
        return [Cluster(**entry) for entry in entries]
    raise TypeError('invalid contents in ' + filename)


def ask_cluster(clusters: "list[Cluster]", name: str) -> "Cluster":
    """Pick the cluster this server belongs to.

    Args:
        clusters: candidate clusters; each must expose a ``name`` attribute.
        name: cluster name to look up, or None to choose automatically
            (single candidate) or interactively (several candidates).

    Returns:
        The selected element of ``clusters``.

    Raises:
        RuntimeError: ``name`` is given but matches no cluster, or ``name``
            is None and ``clusters`` is empty.
    """
    if name is not None:
        for c in clusters:
            if c.name == name:
                return c
        raise RuntimeError('no such cluster: ' + name)

    if not clusters:
        # Without this guard the prompt below would ask for a number in
        # [1-0] and loop forever because no answer could ever be valid.
        raise RuntimeError('no clusters are defined')

    if len(clusters) == 1:
        return clusters[0]

    print('Choose the cluster of this server.\n')
    for i, c in enumerate(clusters):
        print('{}) {}'.format(i + 1, c.name))

    # Keep asking until the answer is an integer within the listed range.
    while True:
        answer = input("Input a number [1-{}]: ".format(len(clusters)))
        try:
            n = int(answer)
        except ValueError:
            continue

        if 1 <= n <= len(clusters):
            return clusters[n - 1]


def dump_rack_information(rack: int):
    """Record the rack number, followed by a newline, in RACK_INFORMATION."""
    os.makedirs('/etc/neco', 0o755, exist_ok=True)
    line = str(rack) + "\n"
    write_file(RACK_INFORMATION, line)


def dump_cluster_information(c: Cluster):
    """Record the cluster name, followed by a newline, in CLUSTER_INFORMATION."""
    os.makedirs('/etc/neco', 0o755, exist_ok=True)
    line = c.name + "\n"
    write_file(CLUSTER_INFORMATION, line)


def dump_sabakan_ipam(bmc_network: ipaddress.IPv4Network):
    """Write the IPAM settings as JSON to SABAKAN_IPAM.

    The node-related values are fixed; the BMC pool and its mask are taken
    from ``bmc_network``.
    """
    node_settings = {
        "max-nodes-in-rack": 28,
        "node-ipv4-pool": "10.69.0.0/20",
        "node-ipv4-range-size": 6,
        "node-ipv4-range-mask": 26,
        "node-index-offset": 3,
        "node-ip-per-node": 3,
        "node-gateway-offset": 1,
    }
    bmc_settings = {
        "bmc-ipv4-pool": str(bmc_network),
        "bmc-ipv4-offset": "0.0.1.0",
        "bmc-ipv4-range-size": 5,
        "bmc-ipv4-range-mask": bmc_network.prefixlen,
        "bmc-ipv4-gateway-offset": 1,
    }

    # Node keys first, then BMC keys: the serialized key order follows
    # insertion order.
    settings = dict(node_settings)
    settings.update(bmc_settings)

    os.makedirs('/etc/neco', 0o755, exist_ok=True)
    write_file(SABAKAN_IPAM, json.dumps(settings))


def systemctl(*args):
    """Run ``systemctl`` with the given arguments.

    Raises:
        subprocess.CalledProcessError: systemctl exited with non-zero status.
    """
    subprocess.run(["systemctl"] + list(args), check=True)


def write_file(path: str, content: str):
    """Create or truncate ``path``, write ``content`` and fsync it."""
    with open(path, "w") as out:
        descriptor = out.fileno()
        out.write(content)
        # Flush Python's buffer first so fsync covers everything just written.
        out.flush()
        os.fsync(descriptor)


def node_addresses(rack: int) -> [ipaddress.IPv4Address]:
    """Return the three node addresses of the server in ``rack``.

    The first address is 10.69.0.0 + 192 * rack + 3; the second and third
    are 64 and 128 addresses after it.  The result is a 3-tuple.
    """
    pool_start = ipaddress.ip_address('10.69.0.0')
    first = pool_start + 192 * rack + 3
    return tuple(first + step for step in (0, 64, 128))


def bastion_address(rack: int, c: Cluster):
    """Return the cluster's bastion base address advanced by ``rack``."""
    base = c.bastion_network
    return base + rack


def wait_devs(*devs):
    """Block until every named device shows an ``inet`` address.

    Runs ``ip -o address`` once per second and returns as soon as all of
    ``devs`` appear in its output with an ``inet`` entry.  There is no
    timeout.

    Raises:
        subprocess.CalledProcessError: the ``ip`` command failed.
    """
    required = set(devs)

    while True:
        listing = subprocess.run(["ip", "-o", "address"],
                                 stdout=subprocess.PIPE,
                                 check=True)
        rows = (row.split()
                for row in listing.stdout.decode('utf-8').split('\n'))
        # Column 1 is the device name and column 2 the address family;
        # rows with fewer than 9 columns are ignored.
        with_inet = {fields[1] for fields in rows
                     if len(fields) >= 9 and fields[2] == 'inet'}
        if required <= with_inet:
            return
        time.sleep(1)


def write_dummy_networkd(name: str, addr: ipaddress.IPv4Address):
    """Write networkd files for a dummy device holding ``addr``/32.

    Creates ``10-<name>.netdev`` and ``10-<name>.network`` under
    /etc/systemd/network.
    """
    netdev_lines = [
        "[NetDev]",
        "Name={}".format(name),
        "Kind=dummy",
        "",
    ]
    network_lines = [
        "[Match]",
        "Name={}".format(name),
        "",
        "[Network]",
        "Address={}/32".format(addr),
        "",
    ]

    # The trailing empty element makes each file end with a newline.
    stem = "/etc/systemd/network/10-{}".format(name)
    write_file(stem + ".netdev", "\n".join(netdev_lines))
    write_file(stem + ".network", "\n".join(network_lines))


def write_link_networkd(name: str, addr: ipaddress.IPv4Address):
    """Write ``10-<name>.network`` for a physical link.

    The file enables LLDP reception and emission and assigns ``addr``/26
    with link scope.
    """
    lines = [
        "[Match]",
        "Name={}".format(name),
        "",
        "[Network]",
        "LLDP=true",
        "EmitLLDP=nearest-bridge",
        "",
        "[Address]",
        "Address={}/26".format(addr),
        "Scope=link",
        "",  # makes the file end with a newline
    ]

    target = "/etc/systemd/network/10-{}.network".format(name)
    write_file(target, "\n".join(lines))


def configure_disable_offload():
    """Install, enable and start disable-offload.service.

    The unit runs ``ethtool -K <dev> tx off rx off`` for ETH0 and ETH1.
    Its ``ConditionVirtualization=!kvm`` line is written as-is into the unit.
    """
    unit = "disable-offload.service"
    ethtool_lines = [
        "ExecStart=/sbin/ethtool -K {} tx off rx off".format(dev)
        for dev in (ETH0, ETH1)
    ]
    lines = [
        "[Unit]",
        "Description=Disable network device offload",
        "Wants=network-online.target",
        "After=network-online.target",
        "ConditionVirtualization=!kvm",
        "",
        "[Service]",
        "Type=oneshot",
    ] + ethtool_lines + [
        "RemainAfterExit=yes",
        "",
        "[Install]",
        "WantedBy=multi-user.target",
        "",  # makes the file end with a newline
    ]

    write_file("/etc/systemd/system/" + unit, "\n".join(lines))

    for command in (("daemon-reload",), ("enable", unit), ("start", unit)):
        systemctl(*command)
    print("configured disable offload")


def configure_default_route(tor1: ipaddress.IPv4Address,
                            tor2: ipaddress.IPv4Address,
                            bastion: ipaddress.IPv4Address):
    """Install, enable and start setup-route.service.

    The unit adds a default route with ``bastion`` as source address and
    two next hops, ``tor1`` and ``tor2``.
    """
    unit = "setup-route.service"
    route_command = (
        "ExecStart=/bin/ip route add 0.0.0.0/0 src {} "
        "nexthop via {} nexthop via {}"
    ).format(bastion, tor1, tor2)
    lines = [
        "[Unit]",
        "Wants=network-online.target",
        "After=network-online.target",
        "",
        "[Service]",
        "Type=oneshot",
        route_command,
        "RemainAfterExit=yes",
        "FailureAction=reboot-immediate",
        "",
        "[Install]",
        "WantedBy=multi-user.target",
        "",  # makes the file end with a newline
    ]

    write_file("/etc/systemd/system/" + unit, "\n".join(lines))

    for command in (("daemon-reload",), ("enable", unit), ("start", unit)):
        systemctl(*command)
    print("configured default gateway")


def configure_bird(asn: int, tor1: ipaddress.IPv4Address,
                   tor2: ipaddress.IPv4Address):
    """Render /etc/bird/bird.conf and enable and start bird.service.

    The configuration comes from /extras/setup/bird.conf.template with
    ASN, TOR1, TOR2 and MASK substituted; the unit file is copied from
    /extras/setup/bird.service.

    Raises:
        KeyError: the template uses a placeholder other than those four.
    """
    with open("/extras/setup/bird.conf.template", "r") as src:
        template = string.Template(src.read())

    values = {
        "ASN": asn,
        "TOR1": str(tor1),
        "TOR2": str(tor2),
        "MASK": NODE_MASK,
    }
    rendered = template.substitute(values)

    os.makedirs("/etc/bird", mode=0o755, exist_ok=True)
    write_file("/etc/bird/bird.conf", rendered)

    unit = "bird.service"
    shutil.copyfile("/extras/setup/" + unit, "/etc/systemd/system/" + unit)
    for command in (("daemon-reload",), ("enable", unit), ("start", unit)):
        systemctl(*command)
    print("started bird.service")


def configure_chrony(c: Cluster):
    """Replace systemd-timesyncd with chrony, using the cluster's NTP servers.

    Masks and stops systemd-timesyncd, renders /etc/chrony.conf from
    /extras/setup/chrony.conf.template, installs the chrony-wait script and
    the two unit files, then enables both units and starts chronyd.
    """
    systemctl('--now', 'mask', 'systemd-timesyncd.service')

    with open("/extras/setup/chrony.conf.template", "r") as src:
        template = string.Template(src.read())

    # One "server <ip> iburst" line per NTP server.
    server_lines = ['server {} iburst'.format(ip) for ip in c.ntp_servers]
    rendered = template.substitute(ntp_servers='\n'.join(server_lines))
    write_file("/etc/chrony.conf", rendered)

    wait_script = "/usr/local/bin/chrony-wait"
    shutil.copyfile("/extras/setup/chrony-wait", wait_script)
    os.chmod(wait_script, 0o755)

    units = ("chrony-wait.service", "chronyd.service")
    for unit in units:
        shutil.copyfile("/extras/setup/" + unit,
                        "/etc/systemd/system/" + unit)

    systemctl("daemon-reload")
    for unit in units:
        systemctl("enable", unit)
    systemctl("start", "chronyd.service")
    print("started chronyd.service")


def enable_networkd():
    """Purge netplan and switch the host to systemd-networkd.

    Raises:
        subprocess.CalledProcessError: apt-get or systemctl failed.
    """
    purge_command = ["apt-get", "-y", "purge", "netplan.io"]
    subprocess.run(purge_command, check=True)
    shutil.rmtree("/etc/netplan", ignore_errors=True)

    # systemd-networkd is disabled by default, so enable it before the
    # restart.
    for verb in ("enable", "restart"):
        systemctl(verb, "systemd-networkd")
    print("purged netplan, enabled systemd-networkd")


def configure_dns(servers: [str], path: str = '/etc/resolv.conf'):
    """Replace the resolver configuration with the given name servers.

    Args:
        servers: name server addresses, written one ``nameserver`` line
            each, in order.
        path: file to replace; defaults to /etc/resolv.conf.

    The existing entry at ``path`` is unlinked first rather than opened for
    writing, so that if it is a symbolic link the link itself is replaced
    by a regular file and the link target is left untouched.
    """
    try:
        os.unlink(path)
    except FileNotFoundError:
        # Nothing to replace; previously this aborted the whole setup.
        pass

    content = ''.join('nameserver {}\n'.format(a) for a in servers)
    with open(path, 'w') as f:
        f.write(content)
        f.flush()
        os.fsync(f.fileno())


def configure(rack: int, dns_servers: [str], c: Cluster):
    """Run the whole network setup for the server in ``rack``.

    Steps, in order: write networkd files and switch to systemd-networkd,
    wait for the devices to get addresses, disable offload, optionally
    rewrite DNS settings (skipped when ``dns_servers`` is None), set the
    default route, start bird and chrony, and finally record the rack,
    cluster and IPAM information.
    """
    node0, node1, node2 = node_addresses(rack)

    # Each ToR address is the first address after the network address of
    # the /NODE_MASK subnet containing the corresponding node address.
    tor1, tor2 = (
        ipaddress.IPv4Network((node, NODE_MASK),
                              strict=False).network_address + 1
        for node in (node1, node2)
    )
    bastion = bastion_address(rack, c)

    write_dummy_networkd("node0", node0)
    for device, address in ((ETH0, node1), (ETH1, node2)):
        write_link_networkd(device, address)
    write_dummy_networkd("bastion", bastion)
    enable_networkd()

    wait_devs(ETH0, ETH1, "bastion")
    print("configured network devices")

    configure_disable_offload()
    if dns_servers is not None:
        configure_dns(dns_servers)
    configure_default_route(tor1, tor2, bastion)
    configure_bird(BASE_ASN + rack, tor1, tor2)
    configure_chrony(c)

    dump_rack_information(rack)
    dump_cluster_information(c)
    dump_sabakan_ipam(c.bmc_network)


def main():
    """Parse the command line and configure this server.

    Exits with an error message when the rack number is negative or when
    not running as root.  Cluster definitions are read from ``cluster.json``
    in the current working directory.
    """
    parser = argparse.ArgumentParser(description='Setup network script')
    parser.add_argument('--dns', dest='dns_servers', metavar='IP',
                        action='append', help='dns server address')
    parser.add_argument('-c', '--cluster', help='cluster name')
    parser.add_argument('rack', type=int, help='rack number')
    options = parser.parse_args()

    if options.rack < 0:
        sys.exit("invalid rack number")
    if os.getuid() != 0:
        sys.exit("run as root")

    available = load_clusters('cluster.json')
    chosen = ask_cluster(available, options.cluster)
    configure(options.rack, options.dns_servers, chosen)


# Run the setup only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
